# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A Langchain LLM component for connecting to Triton + TensorRT LLM backend."""

from __future__ import annotations

import queue
from functools import partial
from typing import Any, Dict, List, Optional

from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models import BaseLLM
from pydantic.v1 import Field, root_validator

from nemoguardrails.llm.providers.trtllm.client import TritonClient

# Tokens that end the visible output: once `TRTLLM._call` receives one of
# these, it stops forwarding/accumulating tokens for the rest of the stream.
STOP_WORDS = ["</s>"]
# Not referenced by the code visible in this module.
# NOTE(review): presumably a default "bad words" list for the model — confirm.
BAD_WORDS = [""]
# Not referenced by the code visible in this module.
# NOTE(review): presumably a default sampling seed — confirm against the client.
RANDOM_SEED = 0


class TRTLLM(BaseLLM):
    """A custom Langchain LLM class that integrates with TRTLLM triton models.

    Arguments:
    server_url: (str) The URL of the Triton inference server to use.
    model_name: (str) The name of the Triton TRT model to use.
    temperature: (float) Temperature to use for sampling.
    top_p: (float) The top-p value to use for sampling.
    top_k: (int) The top-k value to use for sampling.
    tokens: (int) The maximum number of tokens to generate.
    beam_width: (int) The beam width sent with each request.
    repetition_penalty: (float) The penalty to apply to repeated tokens.
    length_penalty: (float) The length penalty sent with each request.
    client: The client object used to communicate with the inference server.
        It is created by `validate_environment` from `server_url`.
    streaming: (bool) Streaming flag; it is not read by the code in this class.
    """

    server_url: str = Field(None, alias="server_url")

    # All the optional arguments.
    model_name: str = "ensemble"
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 0
    top_k: Optional[int] = 1
    tokens: Optional[int] = 100
    beam_width: Optional[int] = 1
    repetition_penalty: Optional[float] = 1.0
    length_penalty: Optional[float] = 1.0
    client: Any
    streaming: Optional[bool] = True

    @root_validator(allow_reuse=True)
    @classmethod
    def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Create the Triton client and store it under ``values["client"]``.

        Args:
            values: The field values collected for the model; ``server_url`` is
                read from it to construct the client.

        Returns:
            The same ``values`` dict with ``client`` populated.

        Raises:
            ImportError: If constructing the client fails with an import error;
                it is re-raised with an installation hint.
        """
        try:
            values["client"] = TritonClient(values["server_url"])

        except ImportError as err:
            raise ImportError(
                "Could not import triton client python package. Please install it with `pip install tritonclient[all]`."
            ) from err
        return values

    @property
    def _get_model_default_parameters(self) -> Dict[str, Any]:
        """Return a fresh dict of the sampling parameters configured on this instance."""
        return {
            "tokens": self.tokens,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "repetition_penalty": self.repetition_penalty,
            "length_penalty": self.length_penalty,
            "beam_width": self.beam_width,
        }

    @property
    def _invocation_params(self, **kwargs: Any) -> Dict[str, Any]:
        """Return the default sampling parameters merged with ``kwargs``.

        Because this is accessed as a property, ``kwargs`` is always empty in
        practice; the signature is kept as-is for backward compatibility.
        """
        params = {**self._get_model_default_parameters, **kwargs}
        return params

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Get all the identifying parameters."""
        return {
            "server_url": self.server_url,
            "model_name": self.model_name,
        }

    @property
    def _llm_type(self) -> str:
        """Return the identifier string for this LLM type."""
        return "trt_llm"

    # NOTE(review): langchain's `BaseLLM` appears to expect `_generate` from
    # subclasses, while `LLM` is the base that bridges to `_call` — confirm
    # that the base class used here is the intended one.
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """
        Execute a streaming inference request and return the generated text.

        Args:
            prompt: The prompt to pass into the model.
            stop: Accepted for interface compatibility. It is not used by this
                implementation; only the module-level ``STOP_WORDS`` end the
                output.
            run_manager: Optional callback manager; when given, every emitted
                token is forwarded to ``on_llm_new_token``.
            **kwargs: Overrides merged over the default sampling parameters
                (sent with the request) and over the identifying parameters
                (used to pick ``model_name``).

        Returns:
            The string generated by the model, up to (and excluding) the first
            stop word. The text is returned whether or not a ``run_manager``
            is supplied.
        """
        # pylint: disable-next=import-outside-toplevel
        from tritonclient.utils import InferenceServerException

        text_callback = None
        if run_manager:
            text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)

        # Both properties build a new dict on each access, so updating them
        # here does not leak state between calls.
        invocation_params = self._get_model_default_parameters
        invocation_params.update(kwargs)
        invocation_params["prompt"] = [[prompt]]
        model_params = self._identifying_params
        model_params.update(kwargs)

        result_queue: queue.Queue[Dict[str, str]] = queue.Queue()
        self.client.load_model(model_params["model_name"])
        self.client.request_streaming(model_params["model_name"], result_queue, **invocation_params)

        pieces: List[str] = []
        send_tokens = True
        try:
            while True:
                response_streaming = result_queue.get()

                # A `None` item or a server exception terminates the stream.
                if response_streaming is None or isinstance(response_streaming, InferenceServerException):
                    break

                token = response_streaming["OUTPUT_0"]
                if token in STOP_WORDS:
                    send_tokens = False
                if not send_tokens:
                    # Keep draining the queue until the terminator arrives, but
                    # drop the stop word and everything after it.
                    continue

                if token == "<0x0A>":
                    # Byte-fallback token for a line feed.
                    token = "\n"  # nosec
                if text_callback:
                    text_callback(token)
                # Accumulate regardless of whether a callback is attached so
                # that callers without a run manager still get the text.
                pieces.append(token)
        finally:
            # Close the stream on every exit path, including when the callback
            # or a malformed item raises.
            self.client.close_streaming()
        return "".join(pieces)

    async def _acall(self, *args: Any, **kwargs: Any) -> str:
        """Async version.

        Delegates to the synchronous `_call`, so it blocks the event loop for
        the duration of the request.
        """
        # TODO: use the async interface from the triton client
        return self._call(*args, **kwargs)
